Llama et al. / FSDP : Fix breaking change in 4.40 for FSDP #31161
Conversation
Thanks for fixing and apologies for breaking this!
Some questions before we can merge:
- Would it make sense to add a test to make sure we don't accidentally break this again?
- Having `**kwargs` in the `forward` method isn't standard amongst transformers models. Is there something special about these models which need this for FSDP? If not, should we be adding it to other models?
- Is there an alternative to using this injection? Relying on kwargs being passed isn't ideal
Thanks!
Yes, I'll add a test in this PR to test this behavior and catch bugs in the future!
Yes agreed, I think we should add it to all 'most-used' models. FSDP is useful for large models, so I would say we should add it for LLMs (llama, gemma, mistral, mixtral, gpt-neo, etc.) to make things consistent. Happy to do that within this PR!
I am not sure; this seems to be something internal to FSDP + CPU offloading, and I don't think we can find a workaround for it :/ Since it used to work before, it should keep working in future transformers versions to ensure backward compatibility. What do you think?
Awesome - thank you!
Makes sense - let's leave it as-is :)
@younesbelkada I'm really sorry I missed the re-request for review.
Remove script - FSDP + CPU offloading is tested in the test suite
Thanks for fixing @younesbelkada, and apologies for the delay in reviewing.
I was able to make the necessary updates to resolve conflicts with `main` through the online editor. As this was just merging a new input argument, it didn't affect the structure of the PR. I did remove the testing_utils scripts (which I would have asked you to remove in a review :) )
What does this PR do?
Fixes: #30523
Reproduction snippet (make sure to run `accelerate config` and select the FSDP options beforehand, then run `accelerate launch script.py`):
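A minimal sketch of what such a reproduction script might look like (illustrative, not the exact snippet from the issue; the checkpoint name and the single training step are assumptions, and the accelerate config is expected to enable FSDP with CPU offloading and activation checkpointing):

```python
# Illustrative reproduction sketch, not the original snippet from the issue.
# Assumes `accelerate config` was run with FSDP enabled, parameters offloaded
# to CPU, and FSDP activation checkpointing turned on, then launched with
# `accelerate launch script.py`.
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # any Llama-based checkpoint

accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# FSDP (and activation-checkpointing) wrapping happens here, driven by the accelerate config.
model = accelerator.prepare(model)
optimizer = accelerator.prepare(torch.optim.AdamW(model.parameters(), lr=1e-5))

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to(accelerator.device)
# On affected versions, the extra kwargs injected during the checkpointed
# forward/backward are where the "unexpected keyword argument" error surfaces.
outputs = model(**inputs, labels=inputs["input_ids"])
accelerator.backward(outputs.loss)
optimizer.step()
```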
#30743 introduced a breaking change for users that use Llama-based models + FSDP + activation checkpointing.
Before #30743, we were able to pass arbitrary kwargs to Llama modules and they were silently ignored. When doing FSDP + activation checkpointing, the target gradient-checkpointing classes are wrapped in a new class, and additional kwargs are passed along to that class's forward pass.
The script above used to work for transformers <= 4.40.0 and does not work anymore due to #30743; re-introducing kwargs in the forward method signature fixes the bug.
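For context, a rough sketch of the shape of the fix (the signature below mirrors a typical Llama decoder layer but is illustrative, not copied from the diff): the layer's `forward` accepts and ignores extra keyword arguments, so the wrapper class that FSDP activation checkpointing creates can pass them through without raising a `TypeError`.

```python
import torch.nn as nn

class LlamaDecoderLayer(nn.Module):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        **kwargs,  # absorbs extra kwargs injected by FSDP's activation-checkpointing wrapper
    ):
        ...
```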
cc @amyeroberts